{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning a model with the Trainer API"
   ]
  },
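  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Dataset preparation and preprocessing\n",
    "\n",
    "This part recaps the previous installment:\n",
    "- Load the dataset with the `datasets` package\n",
    "- Load the tokenizer that belongs to the pretrained checkpoint\n",
    "- Define the preprocessing function to pass to `Dataset.map`\n",
    "- Define a `DataCollator` to build the training batches"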
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (C:\\Users\\Administrator\\.cache\\huggingface\\datasets\\glue\\mrpc\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4bdadebec1b4fa681fd5b7370f11abc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at C:\\Users\\Administrator\\.cache\\huggingface\\datasets\\glue\\mrpc\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-d7c1a56b0a079691.arrow\n",
      "Loading cached processed dataset at C:\\Users\\Administrator\\.cache\\huggingface\\datasets\\glue\\mrpc\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-4551ce60e93aa1ca.arrow\n",
      "Loading cached processed dataset at C:\\Users\\Administrator\\.cache\\huggingface\\datasets\\glue\\mrpc\\1.0.0\\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad\\cache-8e3dd97f55b2d13b.arrow\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "import datasets\n",
    "checkpoint = 'bert-base-cased'\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "raw_datasets = datasets.load_dataset('glue', 'mrpc')\n",
    "\n",
    "def tokenize_function(sample):\n",
    "    return tokenizer(sample['sentence1'], sample['sentence2'], truncation=True)\n",
    "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
    "\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
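  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`DataCollatorWithPadding` pads each batch only up to the length of its longest sequence rather than to one global maximum. The snippet below is a minimal pure-Python sketch of that idea, not the library's implementation: `pad_batch` and the sample token ids are hypothetical, and the pad id of 0 is an assumption.\n",
    "\n",
    "```python\n",
    "def pad_batch(batch, pad_id=0):\n",
    "    # Pad every sequence to the length of the longest one in this batch.\n",
    "    max_len = max(len(ids) for ids in batch)\n",
    "    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch]\n",
    "    # The attention mask marks real tokens with 1 and padding with 0.\n",
    "    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]\n",
    "    return {'input_ids': input_ids, 'attention_mask': attention_mask}\n",
    "\n",
    "# Hypothetical token ids of different lengths.\n",
    "batch = pad_batch([[101, 7592, 102], [101, 7592, 2088, 999, 102]])\n",
    "print(batch['input_ids'])\n",
    "print(batch['attention_mask'])\n",
    "```\n",
    "\n",
    "Because the padded length is decided per batch, short batches stay short, which is why the collator is applied at batching time instead of inside `tokenize_function`."
   ]
  },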
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Loading the model to fine-tune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
   ]
  },
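  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hugging Face is thoughtful here: the warning spells out exactly what happened. We load the model with a `ForSequenceClassification` head. The `bert-base-cased` checkpoint does come with its own pretraining heads, but they do not match our binary classification task, so, as the warning shows, they are dropped and a randomly initialized classification head is used instead.\n",
    "\n",
    "That is why the message also says: \"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\"\n",
    "\n",
    "## 3. Training with `Trainer`\n",
    "\n",
    "`Trainer` is a high-level API in the Hugging Face transformers library that wraps the training and evaluation loop."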
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(output_dir='test_trainer')\n",
    "\n",
    "trainer = Trainer(\n",
    "    model,\n",
    "    training_args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,  # optional once tokenizer is passed: a padding collator is then created from the tokenizer automatically\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the parameters that `TrainingArguments` and `Trainer` accept:\n",
    "\n",
    "- https://huggingface.co/transformers/master/main_classes/trainer.html\n",
    "- https://huggingface.co/transformers/master/main_classes/trainer.html#trainingarguments\n",
    "\n",
    "```python\n",
    "TrainingArguments(\n",
    "    output_dir: Union[str, NoneType] = None,\n",
    "    overwrite_output_dir: bool = False,\n",
    "    do_train: bool = False,\n",
    "    do_eval: bool = None,\n",
    "    do_predict: bool = False,\n",
    "    evaluation_strategy: transformers.trainer_utils.EvaluationStrategy = 'no',\n",
    "    prediction_loss_only: bool = False,\n",
    "    per_device_train_batch_size: int = 8,  # default batch size is 8 per device\n",
    "    per_device_eval_batch_size: int = 8,\n",
    "    per_gpu_train_batch_size: Union[int, NoneType] = None,\n",
    "    per_gpu_eval_batch_size: Union[int, NoneType] = None,\n",
    "    gradient_accumulation_steps: int = 1,\n",
    "    eval_accumulation_steps: Union[int, NoneType] = None,\n",
    "    learning_rate: float = 5e-05,\n",
    "    weight_decay: float = 0.0,\n",
    "    adam_beta1: float = 0.9,\n",
    "    adam_beta2: float = 0.999,\n",
    "    adam_epsilon: float = 1e-08,\n",
    "    max_grad_norm: float = 1.0,\n",
    "    num_train_epochs: float = 3.0,   # trains for 3 epochs by default\n",
    "    ...\n",
    "```\n",
    "\n",
    "```python\n",
    "Trainer(\n",
    "    model: Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None,\n",
    "    args: transformers.training_args.TrainingArguments = None,\n",
    "    data_collator: Union[DataCollator, NoneType] = None,\n",
    "    train_dataset: Union[torch.utils.data.dataset.Dataset, NoneType] = None,\n",
    "    eval_dataset: Union[torch.utils.data.dataset.Dataset, NoneType] = None,\n",
    "    tokenizer: Union[ForwardRef('PreTrainedTokenizerBase'), NoneType] = None,\n",
    "    model_init: Callable[[], transformers.modeling_utils.PreTrainedModel] = None,\n",
    "    compute_metrics: Union[Callable[[transformers.trainer_utils.EvalPrediction], Dict], NoneType] = None,\n",
    "    callbacks: Union[List[transformers.trainer_callback.TrainerCallback], NoneType] = None,\n",
    "    optimizers: Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),  # AdamW is used by default\n",
    ")\n",
    "Docstring:\n",
    "Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.\n",
    "```\n",
    "\n",
    "As you can see, `Trainer` covers the parameters and design choices that come up during training: here we can specify the training and validation sets, the data collator, the metrics, and the optimizer, and we supply the hyperparameters through `TrainingArguments`.\n",
    "\n",
    "By default, `Trainer` and `TrainingArguments` use:\n",
    "- a batch size of 8 per device\n",
    "- 3 epochs\n",
    "- the AdamW optimizer\n",
    "\n",
    "Once everything is defined, call `.train()` to start training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "        </style>\n",
       "      \n",
       "      <progress value='1377' max='1377' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1377/1377 06:20, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.539400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.319400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1377, training_loss=0.35569445984728887, metrics={'train_runtime': 383.0158, 'train_samples_per_second': 3.595, 'total_flos': 530185443455520, 'epoch': 3.0})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
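  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 1377 steps shown in the progress bar follow directly from the defaults. Here is a small sketch of the arithmetic, assuming a single device, no gradient accumulation, and 3668 examples in the MRPC training split (that count is not printed anywhere in this notebook, so treat it as an assumption):\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "train_examples = 3668   # assumed size of the MRPC training split\n",
    "batch_size = 8          # default per_device_train_batch_size\n",
    "epochs = 3              # default num_train_epochs\n",
    "\n",
    "# The last batch of an epoch may be smaller, hence the ceiling.\n",
    "steps_per_epoch = math.ceil(train_examples / batch_size)\n",
    "total_steps = steps_per_epoch * epochs\n",
    "print(steps_per_epoch, total_steps)\n",
    "```\n",
    "\n",
    "This gives 459 steps per epoch and 1377 steps overall, matching `global_step=1377` in the training output."
   ]
  },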
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we use the `Trainer` to make predictions.\n",
    "\n",
    "`trainer.predict()` returns a named tuple with three fields: `predictions`, `label_ids`, and `metrics`.\n",
    "\n",
    "A few notes on these fields:\n",
    "- `predictions` are simply the logits\n",
    "- `label_ids` are not predicted ids but the ground-truth label ids that come with the dataset, so they are absent if the input dataset has no labels\n",
    "- `metrics` likewise contains the loss and other label-dependent values only when the input dataset provides labels\n",
    "\n",
    "`metrics` can also include our own custom entries; for that we pass a `compute_metrics` function when creating the `Trainer`.\n",
    "\n",
    "Documentation: https://huggingface.co/transformers/master/main_classes/trainer.html#transformers.Trainer.predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "        </style>\n",
       "      \n",
       "      <progress value='51' max='51' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [51/51 00:03]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(408, 2)\n",
      "(408,)\n",
      "{'eval_loss': 0.7387174963951111, 'eval_runtime': 3.2872, 'eval_samples_per_second': 124.117}\n"
     ]
    }
   ],
   "source": [
    "predictions = trainer.predict(tokenized_datasets['validation'])\n",
    "print(predictions.predictions.shape)  # logits\n",
    "# array([[-2.7887206,  3.1986978],\n",
    "#       [ 2.5258656, -1.832253 ], ...], dtype=float32)\n",
    "print(predictions.label_ids.shape) # array([1, 0, 0, 1, 0, 1, 0, 1, 1, 1, ...], dtype=int64)\n",
    "print(predictions.metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the predictions and labels in hand, we can compute the relevant metrics.\n",
    "\n",
    "The Hugging Face `datasets` library lets us load the metrics associated with a dataset directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.8455882352941176, 'f1': 0.8911917098445595}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_metric\n",
    "\n",
    "preds = np.argmax(predictions.predictions, axis=-1)\n",
    "\n",
    "metric = load_metric('glue', 'mrpc')\n",
    "metric.compute(predictions=preds, references=predictions.label_ids)"
   ]
  },
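  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two scores above concrete, here is a NumPy-only sketch of the path from logits to accuracy and F1. The logits and labels are made-up illustrative values, not taken from the run above, and `softmax`, `accuracy`, and `f1` are hypothetical helper names. It assumes the `glue`/`mrpc` metric reports plain accuracy and binary F1 for the positive class.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the row maximum for numerical stability.\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    exp = np.exp(shifted)\n",
    "    return exp / exp.sum(axis=-1, keepdims=True)\n",
    "\n",
    "def accuracy(preds, labels):\n",
    "    return float((preds == labels).mean())\n",
    "\n",
    "def f1(preds, labels):\n",
    "    # Binary F1 for the positive class (label 1).\n",
    "    tp = int(((preds == 1) & (labels == 1)).sum())\n",
    "    fp = int(((preds == 1) & (labels == 0)).sum())\n",
    "    fn = int(((preds == 0) & (labels == 1)).sum())\n",
    "    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0\n",
    "\n",
    "logits = np.array([[-2.8, 3.2], [2.5, -1.8], [0.3, 0.9], [-0.5, 1.1]])\n",
    "labels = np.array([1, 0, 0, 1])\n",
    "\n",
    "probs = softmax(logits)             # each row sums to 1\n",
    "preds = np.argmax(logits, axis=-1)  # same result as argmax over probs\n",
    "print(preds, accuracy(preds, labels), f1(preds, labels))\n",
    "```\n",
    "\n",
    "Softmax does not change which class has the highest score, which is why taking `argmax` over the raw logits is enough."
   ]
  },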
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Documentation of the `glue` metric:\n",
    "```\n",
    "Args:\n",
    "    predictions: list of predictions to score.\n",
    "        Each translation should be tokenized into a list of tokens.\n",
    "    references: list of lists of references for each translation.\n",
    "        Each reference should be tokenized into a list of tokens.\n",
    "Returns: depending on the GLUE subset, one or several of:\n",
    "    \"accuracy\": Accuracy\n",
    "    \"f1\": F1 score\n",
    "    \"pearson\": Pearson Correlation\n",
    "    \"spearmanr\": Spearman Correlation\n",
    "    \"matthews_correlation\": Matthew Correlation\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Building the `compute_metrics` function for `Trainer`\n",
    "\n",
    "As noted earlier, `Trainer` accepts a `compute_metrics` function that reports whichever metrics we want to see. Let's build a useful one and use it the next time we train.\n",
    "\n",
    "`compute_metrics` has requirements on its input and output:\n",
    "- Input: an `EvalPrediction` object, a named tuple with at least a `predictions` field and a `label_ids` field. According to the source code, `predictions` here **are the logits**.\n",
    "- Output: a dictionary mapping strings to floats, the strings being the metric names and the floats their values.\n",
    "\n",
    "Source code: https://huggingface.co/transformers/master/_modules/transformers/trainer.html#Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from datasets import load_metric\n",
    "\n",
    "# Load the metric once instead of on every evaluation call.\n",
    "metric = load_metric(\"glue\", \"mrpc\")\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    logits, labels = eval_preds.predictions, eval_preds.label_ids\n",
    "    # The line above can be shortened to the following, since eval_preds behaves like a tuple:\n",
    "    # logits, labels = eval_preds\n",
    "    predictions = np.argmax(logits, axis=-1)  # class with the highest logit\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
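  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To summarize the process:\n",
    "\n",
    "- First we define a `compute_metrics` function and hand it to the `Trainer`;\n",
    "- The `Trainer` trains the model, and the model runs on the samples to produce predictions (logits);\n",
    "- The `Trainer` then packs the predictions together with the `label_ids` from the dataset into one object and passes it to `compute_metrics`;\n",
    "- `compute_metrics` computes the corresponding metrics and returns them."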
   ]
  },
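  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hand-off summarized above can be tried out without running a training job. In this sketch, `FakeEvalPrediction` is a hypothetical stand-in for the `EvalPrediction` named tuple, and `accuracy_only_metrics` is an illustrative function that computes accuracy directly with NumPy instead of calling `load_metric`. The logits and labels are made up.\n",
    "\n",
    "```python\n",
    "from collections import namedtuple\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for the EvalPrediction named tuple.\n",
    "FakeEvalPrediction = namedtuple('FakeEvalPrediction', ['predictions', 'label_ids'])\n",
    "\n",
    "def accuracy_only_metrics(eval_preds):\n",
    "    # A named tuple unpacks like a plain tuple.\n",
    "    logits, labels = eval_preds\n",
    "    preds = np.argmax(logits, axis=-1)\n",
    "    # Return a dict that maps metric names to plain floats.\n",
    "    return {'accuracy': float((preds == labels).mean())}\n",
    "\n",
    "eval_preds = FakeEvalPrediction(\n",
    "    predictions=np.array([[0.1, 0.9], [2.0, -1.0], [0.2, 0.4]]),\n",
    "    label_ids=np.array([1, 0, 0]),\n",
    ")\n",
    "result = accuracy_only_metrics(eval_preds)\n",
    "print(result)\n",
    "```\n",
    "\n",
    "Calling the function by hand like this is a quick way to check the input and output contract before passing it to the `Trainer`."
   ]
  },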
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Training with `compute_metrics`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "        </style>\n",
       "      \n",
       "      <progress value='1377' max='1377' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1377/1377 06:51, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1</th>\n",
       "      <th>Runtime</th>\n",
       "      <th>Samples Per Second</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.329815</td>\n",
       "      <td>0.867647</td>\n",
       "      <td>0.903571</td>\n",
       "      <td>5.873300</td>\n",
       "      <td>69.467000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.497900</td>\n",
       "      <td>0.600649</td>\n",
       "      <td>0.845588</td>\n",
       "      <td>0.897227</td>\n",
       "      <td>17.319700</td>\n",
       "      <td>23.557000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.283200</td>\n",
       "      <td>0.605053</td>\n",
       "      <td>0.872549</td>\n",
       "      <td>0.910345</td>\n",
       "      <td>9.244300</td>\n",
       "      <td>44.135000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1377, training_loss=0.32063739751678666, metrics={'train_runtime': 414.1719, 'train_samples_per_second': 3.325, 'total_flos': 530351810395680, 'epoch': 3.0})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_args = TrainingArguments(output_dir='test_trainer', evaluation_strategy='epoch')\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)  # new model\n",
    "trainer = Trainer(\n",
    "    model,\n",
    "    training_args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,  # optional once tokenizer is passed: a padding collator is then created from the tokenizer automatically\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
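  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, once `compute_metrics` is supplied, the `Trainer` also prints the additional metrics at each evaluation (here once per epoch, since `evaluation_strategy='epoch'`), which makes it easy to keep track of training progress."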
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
